Flash Attention

Paper reading May 2025

Rostislav Stoyanov

Attention

  • Self-attention is a key component of the Transformer architecture.

  • However, its time and memory cost scale quadratically with the sequence length, making it inefficient for long sequences.

  • There have been various attempts to improve the efficiency of attention mechanisms by making them linear or near-linear in the sequence length.

Flash Attention

  • Flash attention [1], [2] shows that the limiting factor in the performance of attention is the memory bandwidth.

  • We can reduce computation time by making the implementation more I/O-aware.

Speedup over the PyTorch implementation of attention on GPT-2. [1]

Flash Attention 2

Attention forward and backward pass on A100 GPU. [2]

Table of contents

Attention

  • Flash attention relies on a technique called “tiling” to break the attention computation into smaller chunks that fit into the GPU’s fast on-chip memory.

  • This allows for efficient computation without the need for large intermediate tensors.

  • However, for tiling to work we need associative operations, which is not the case for the standard attention mechanism, as softmax is not associative.

Attention

  • The softmax function is a key component of the attention mechanism.

  • It is used to compute the attention weights, which scale the values based on the similarity of the queries and keys. \[ \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^\mathsf{T}}{\sqrt{d_k}}) V, \text{where} \\ Q, K, V \in \mathbb{R}^{N \times d_k}, \\ N \text{ is the sequence length,}\\ d_k \text{ is the dimension of the heads.} \]
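As a point of reference, here is a minimal NumPy sketch of this formula (illustrative names and shapes, not the implementation from the paper); note that it materializes the full \(N \times N\) score matrix:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Q, K, V: (N, d_k) arrays; returns the (N, d_k) attention output."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)        # the N x N intermediate matrix
    weights = np.exp(scores)               # naive (unsafe) softmax, see below
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
N, d_k = 8, 4
Q, K, V = rng.standard_normal((3, N, d_k))
print(naive_attention(Q, K, V).shape)      # (8, 4)
```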

Softmax

  • The softmax function is defined for a vector \(x = \{x_i\}_{i=1}^N \in \mathbb{R}^N\) as follows: \[ \text{Softmax}(x) = \left\{\frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}\right\}_{i=1}^N\]

  • If any of the \(x_i\) is very large, the exponentials will overflow and the softmax will return NaN.

  • The largest representable float32 value is \(3.4028235 \times 10^{38}\), which means the exponentials will overflow if any of the \(x_i\) is larger than \(88.722839\).
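A small NumPy demonstration of this failure mode (the values are chosen only to trigger float32 overflow):

```python
import numpy as np

x = np.array([10.0, 100.0], dtype=np.float32)
e = np.exp(x)            # exp(100) overflows float32 -> inf
print(e)                 # [2.2026466e+04        inf]
print(e / e.sum())       # [ 0. nan] -- the softmax output is ruined
```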

Safe softmax

  • In order to mitigate this, the following trick is used: \[ \text{Softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}} = \\ \frac{C}{C}\frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}} = \\ \frac{C e^{x_i}}{\sum_{j=1}^{N} C e^{x_j}} = \frac{e^{x_i + \log C}}{\sum_{j=1}^{N} e^{x_j + \log C}},\\ \text{where typically } \log C = -\max_{i} x_i \]
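The same example with the shift applied; a minimal sketch of the trick above:

```python
import numpy as np

def safe_softmax(x):
    z = np.exp(x - x.max())   # largest exponent is now exp(0) = 1
    return z / z.sum()

x = np.array([10.0, 100.0], dtype=np.float32)
print(safe_softmax(x))        # [8.194013e-40 1.000000e+00] -- no overflow
```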

Safe softmax

3-pass implementation

  • The naive implementation requires three passes over the data.

  • Let’s introduce the following notation:

    • \(\left\{m_{i}\right\}: \max _{j=1}^{i}\left\{x_{j}\right\}\), with initial value \(m_{0}=-\infty\).
    • \(\left\{d_{i}\right\}: \sum_{j=1}^{i} e^{x_{j}-m_{N}}\), with initial value \(d_{0}=0\).
    • \(\left\{a_{i}\right\}\) : the final softmax value.

Safe softmax

3-pass implementation

  1. For \(i \leftarrow 1, N\) do \[ m_{i} \leftarrow \max \left(m_{i-1}, x_{i}\right) \qquad(1)\]

  2. For \(i \leftarrow 1, N\) do \[ d_{i} \leftarrow d_{i-1}+e^{x_{i}-m_{N}} \qquad(2)\]

  3. For \(i \leftarrow 1, N\) do \[ a_{i} \leftarrow \frac{e^{x_{i}-m_{N}}}{d_{N}} \qquad(3)\]
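A direct Python transcription of the three passes (variable names follow the notation above; illustrative only):

```python
import numpy as np

def softmax_3pass(x):
    N = len(x)
    m = -np.inf                        # m_0 = -inf
    for i in range(N):                 # pass 1: running maximum (Equation 1)
        m = max(m, x[i])
    d = 0.0                            # d_0 = 0
    for i in range(N):                 # pass 2: sum of shifted exps (Equation 2)
        d += np.exp(x[i] - m)
    a = np.empty(N)
    for i in range(N):                 # pass 3: normalized outputs (Equation 3)
        a[i] = np.exp(x[i] - m) / d
    return a
```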

Online softmax

  • We want to fuse the operations in a single loop, however (Equation 2) and (Equation 3) cannot be fused as (Equation 3) depends on the value of \(m_N\) which is not known until the end of the loop.

  • We can create a surrogate sequence \(d'_{i}\) that is computed in the same way as \(d_i\) but does not depend on \(m_N\): \[ d'_{i} = \sum_{j=1}^i e^{x_j-m_i}. \]

  • Furthermore \(d_{N} = d'_{N}\), so we can replace \(d_N\) with \(d'_N\) in (Equation 3).

Online softmax

  • We can find a recurrence relation between \(d'_{i}\) and \(d'_{i-1}\): \[ d'_{i} \overset{\text{def}}{=} \sum_{j=1}^i e^{x_j-m_i} \overset{\text{unroll}}{=} \left(\sum_{j=1}^{i-1} e^{x_j-m_i}\right) + e^{x_i-m_i} \\ = \left(\sum_{j=1}^{i-1} e^{x_j-m_i+m_{i-1}-m_{i-1}}\right) + e^{x_i-m_i}\\ = \left(\sum_{j=1}^{i-1} e^{x_j-m_{i-1}}\right)e^{m_{i-1}-{m_{i}}} + e^{x_i-m_i} \\ = d'_{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \]

Online softmax

2-pass implementation

  1. For \(i \leftarrow 1, N\) do \[ m_{i} \leftarrow \max \left(m_{i-1}, x_{i}\right)\\ d'_i \leftarrow d'_{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \qquad(4)\]

  2. For \(i \leftarrow 1, N\) do \[ a_{i} \leftarrow \frac{e^{x_{i}-m_{N}}}{d'_{N}} \qquad(5)\]
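The fused passes in Python (a minimal sketch of (Equation 4) and (Equation 5)):

```python
import numpy as np

def softmax_2pass(x):
    N = len(x)
    m, d = -np.inf, 0.0
    for i in range(N):                 # pass 1: fused max and sum
        m_new = max(m, x[i])
        d = d * np.exp(m - m_new) + np.exp(x[i] - m_new)   # (Equation 4)
        m = m_new
    return np.array([np.exp(x[i] - m) / d for i in range(N)])  # (Equation 5)
```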

Multi-pass self-attention

  • While we can’t decrease the number of passes for softmax in isolation, we can decrease the number of passes for the self-attention mechanism by finding a one-pass recurrence relation for the output matrix \(O\).

  • \(Q[k,:]\) : the \(k\)-th row vector of the \(Q\) matrix.

  • \(K^{T}[:, i]\) : the \(i\)-th column vector of the \(K^{T}\) matrix.

  • \(O[k,:]\) : the \(k\)-th row of the output matrix \(O\).

  • \(V[i,:]\) : the \(i\)-th row of the \(V\) matrix.

  • \(\left\{\boldsymbol{o}_{i}\right\}\) : \(\sum_{j=1}^{i} a_{j} V[j,:]\), a row vector storing the partial aggregation result \(A[k,:i] \times V[:i,:]\).

Multi-pass self-attention

  1. For \(i \leftarrow 1, N\) do

    \[\begin{aligned} x_{i} & \leftarrow Q[k,:] K^{T}[:, i] \\ m_{i} & \leftarrow \max \left(m_{i-1}, x_{i}\right) \\ d_{i}^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}+e^{x_{i}-m_{i}} \end{aligned}\]
  2. For \(i \leftarrow 1, N\) do \[ a_{i} \leftarrow \frac{e^{x_{i}-m_{N}}}{d_{N}^{\prime}} \qquad(6)\] \[ \boldsymbol{o}_{i} \leftarrow \boldsymbol{o}_{i-1}+a_{i} V[i,:] \qquad(7)\]

  3. \[ O[k,:] \leftarrow \boldsymbol{o}_{N} \]
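A Python sketch of this two-pass scheme for a single output row (the \(1/\sqrt{d_k}\) scaling is omitted to match the notation above; note that the scores \(x_i\) must be cached between the passes, which is exactly what the single-pass version below removes):

```python
import numpy as np

def attention_row_2pass(Q, K, V, k):
    """Compute output row O[k, :] with the two-pass scheme."""
    N, d_k = K.shape
    m, d = -np.inf, 0.0
    x = np.empty(N)                    # cached scores, needed in pass 2
    for i in range(N):                 # pass 1: scores, online max and sum
        x[i] = Q[k, :] @ K.T[:, i]
        m_new = max(m, x[i])
        d = d * np.exp(m - m_new) + np.exp(x[i] - m_new)
        m = m_new
    o = np.zeros(d_k)
    for i in range(N):                 # pass 2: weights and aggregation
        a_i = np.exp(x[i] - m) / d     # (Equation 6)
        o += a_i * V[i, :]             # (Equation 7)
    return o
```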

Recurrence relation for \(O\)

  • Let’s substitute the expression for \(a_i\) from (Equation 6) into (Equation 7) and unroll: \[ \boldsymbol{o}_{i}:=\sum_{j=1}^{i}\left(\frac{e^{x_{j}-m_{N}}}{d_{N}^{\prime}} V[j,:]\right) \]

  • This depends on \(m_N\) and \(d'_N\), so we again create a surrogate sequence \(\boldsymbol{o}'_{i}\): \[ \boldsymbol{o}_{i}^{\prime}:=\left(\sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:]\right) \]

Recurrence relation for \(O\)

\[\begin{align*} \boldsymbol{o}_{i}^{\prime} & =\sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:] \\ & =\left(\sum_{j=1}^{i-1} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:]\right)+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \\ & =\left(\sum_{j=1}^{i-1} \frac{e^{x_{j}-m_{i-1}}}{d_{i-1}^{\prime}} \frac{e^{x_{j}-m_{i}}}{e^{x_{j}-m_{i-1}}} \frac{d_{i-1}^{\prime}}{d_{i}^{\prime}} V[j,:]\right)+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \\ & =\left(\sum_{j=1}^{i-1} \frac{e^{x_{j}-m_{i-1}}}{d_{i-1}^{\prime}} V[j,:]\right) \frac{d_{i-1}^{\prime}}{d_{i}^{\prime}} e^{m_{i-1}-m_{i}}+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \\ \end{align*}\]

Recurrence relation for \(O\)

\[\begin{align*} \boldsymbol{o}_{i}^{\prime} & =\sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:] \\ & =\boldsymbol{o}_{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}}{d_{i}^{\prime}}+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \end{align*}\]

Flash Attention

Single-pass self-attention

  1. For \(i \leftarrow 1, N\) do \[\begin{aligned} x_i & \leftarrow Q[k,:] K^T[:, i] \\ m_i & \leftarrow \max \left(m_{i-1}, x_i\right) \\ d_i^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ \boldsymbol{o}_i^{\prime} & \leftarrow \boldsymbol{o}_{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_i}}{d_i^{\prime}}+\frac{e^{x_i-m_i}}{d_i^{\prime}} V[i,:] \end{aligned}\]
  2. \[ O[k,:] \leftarrow \boldsymbol{o}_{N}^{\prime}\]
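The corresponding Python sketch for one output row (illustrative; the actual Flash Attention kernel processes tiles of rows and columns, not single elements):

```python
import numpy as np

def attention_row_1pass(Q, K, V, k):
    """Compute output row O[k, :] in a single pass over K and V."""
    N, d_k = K.shape
    m, d = -np.inf, 0.0
    o = np.zeros(d_k)
    for i in range(N):
        x_i = Q[k, :] @ K.T[:, i]
        m_new = max(m, x_i)
        d_new = d * np.exp(m - m_new) + np.exp(x_i - m_new)
        # rescale the partial output, then fold in the new value row
        o = o * (d * np.exp(m - m_new) / d_new) \
            + (np.exp(x_i - m_new) / d_new) * V[i, :]
        m, d = m_new, d_new
    return o
```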

Flash Attention 2

  • It turns out we don’t need to divide by \(d'_N\) at every step; we can divide by it once at the end.
  • If we look at (Equation 6) we can see that every \(a_i\) is divided by the same \(d'_N\), so we can instead divide the final result by \(d'_N\).
  • We can create a surrogate sequence \(\boldsymbol{s}_{i}\) that is computed in the same way as \(\boldsymbol{o}'_i\) but does not depend on \(d'_N\): \[ \boldsymbol{s}_{i} :=\left(\sum_{j=1}^{i} e^{x_{j}-m_{i}} V[j,:]\right) \]

Flash Attention 2

  1. For \(i \leftarrow 1, N\) do \[\begin{aligned} x_i & \leftarrow Q[k,:] K^T[:, i] \\ m_i & \leftarrow \max \left(m_{i-1}, x_i\right) \\ d_i^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ \boldsymbol{s}_i & \leftarrow \boldsymbol{s}_{i-1} e^{m_{i-1}-m_i}+{e^{x_i-m_i}} V[i,:] \end{aligned}\]
  2. \[ O[k,:] \leftarrow \frac{\boldsymbol{s}_{N}}{d_N^{\prime}}\]
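The same row computation with the division deferred to the end (a sketch of the trick, not the paper's kernel):

```python
import numpy as np

def attention_row_fa2(Q, K, V, k):
    """Output row O[k, :] with a single division at the very end."""
    N, d_k = K.shape
    m, d = -np.inf, 0.0
    s = np.zeros(d_k)                  # unnormalized output s_i
    for i in range(N):
        x_i = Q[k, :] @ K.T[:, i]
        m_new = max(m, x_i)
        d = d * np.exp(m - m_new) + np.exp(x_i - m_new)
        s = s * np.exp(m - m_new) + np.exp(x_i - m_new) * V[i, :]
        m = m_new
    return s / d                       # divide once, at the end
```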

GPU programming

  • The single-pass attention allows us to use tiling, as we don’t need to store large intermediate tensors.
  • In order to understand how tiling works and why it is useful we need some knowledge of GPU programming.

GPU programming

CPU vs GPU

  • The CPU is optimized for low latency and high single-threaded performance, while the GPU is optimized for high throughput and parallelism.

The GPU Devotes More Transistors to Data Processing [3]

GPU programming

Running code on CPU

Running code on CPU (Taken from CS149 Stanford)

GPU programming

Running code on GPU

Running code on GPU (Taken from CS149 Stanford)

CUDA programming

Thread hierarchy

CUDA thread hierarchy (Taken from CS149 Stanford)

CUDA programming

Memory model

CUDA memory model (Taken from CS149 Stanford)

CUDA programming

Memory model

CUDA types of memory (Taken from CS149 Stanford)

CUDA programming

Memory types

CUDA memory types (Taken from YouTube)

CUDA programming

Memory types

CUDA memory types [4]

CUDA programming

Synchronization

CUDA synchronization (Taken from CS149 Stanford)

Tiling

Basic matrix multiplication

Small matrix multiplication example [4]

Basic matmul kernel [4]

Tiling

Issues with basic matrix multiplication

Thread memory access [4]

Tiling

Tiled matrix multiplication

Breaking down the matrix into tiles [4]

Tiling

Tiled matmul execution order

Tiled execution order [4]

Tiling

Tiled matrix multiplication kernel

Tiled matmul kernel [4]
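Since the kernel itself is shown only as a figure, here is a NumPy sketch of the blocking idea (in a real CUDA kernel each tile of `A` and `B` would be staged in shared memory before the partial products are accumulated):

```python
import numpy as np

def tiled_matmul(A, B, tile=2):
    """Blocked multiply of square matrices; assumes N is divisible by tile."""
    N = A.shape[0]
    C = np.zeros((N, N))
    for i in range(0, N, tile):
        for j in range(0, N, tile):
            for p in range(0, N, tile):    # walk tiles along the shared dim
                C[i:i+tile, j:j+tile] += (
                    A[i:i+tile, p:p+tile] @ B[p:p+tile, j:j+tile]
                )
    return C

A = np.arange(16.0).reshape(4, 4)
B = np.eye(4)
assert np.allclose(tiled_matmul(A, B), A @ B)
```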

Flash Attention

Tiled attention

Flash attention computation on GPU [5]

Flash Attention

Another graphical representation of Flash Attention [2]

Flash Attention

Algorithm forward pass

Flash attention algorithm [1]

Flash Attention 2

  • Flash Attention 2 is very similar; however, two tricks are applied:
    • Rescaling is done only once at the end, instead of at every step.
    • There is no need to keep both the maximum \(m^j\) and the sum of exponentials \(l^{j}\); instead, store \(L^{j} := m^j + \log(l^j)\).
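A quick numerical check of the second trick: storing \(L = m + \log(l)\) alone is enough, since \(e^{x_i - m}/l = e^{x_i - L}\):

```python
import numpy as np

x = np.random.default_rng(0).standard_normal(16)
m = x.max()
l = np.exp(x - m).sum()
L = m + np.log(l)                       # the only statistic that is stored
assert np.allclose(np.exp(x - m) / l, np.exp(x - L))
```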

Flash Attention 2

Algorithm forward pass

Flash attention 2 algorithm [2]

Flash Attention 2

  • There are some additional improvements added in Flash Attention 2:
    • Work parallelization is done not only over the batch size but also over the sequence length.
    • Work partitioning between warps is changed to reduce shared-memory reads and writes.
    • More details in paper [2].

Conclusion

  • Flash attention is a highly efficient implementation of the attention mechanism that reduces memory bandwidth usage and computation time.
  • It achieves this by reformulating the softmax and attention computations to allow for tiling and single-pass execution.
  • Flash attention 2 further improves the efficiency by rescaling only once at the end and optimizing memory access patterns.

References

[1]
T. Dao, D. Y. Fu, S. Ermon, A. Rudra, and C. Ré, “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv, Jun. 2022. doi: 10.48550/arXiv.2205.14135.
[2]
T. Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” Available: https://tridao.me/publications/flash2/flash2.pdf
[3]
“CUDA C++ Programming Guide.” NVIDIA. Accessed: May 25, 2025. [Online]. Available: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#
[4]
W.-M. W. Hwu, D. B. Kirk, and I. El Hajj, Programming Massively Parallel Processors: A Hands-on Approach. Morgan Kaufmann, 2022.
[5]
Z. Ye, “From Online Softmax to FlashAttention,” Available: https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf